
from torch import nn
import transformers
from transformers import AutoModelForSequenceClassification
import torch
import numpy as np
import random

class AttentionClassifier(nn.Module):
    """
    Entailment/contradiction detection attention layer.

    Scores every span from its CLS representation, normalises the scores
    into attention weights, and uses them to pool the per-span logits into
    one sentence-level probability.

    Train/eval state is tracked with the standard ``nn.Module.training``
    flag, so ``train()``, ``train(mode)`` and ``eval()`` behave as for any
    other module and a parent module can switch modes recursively.
    ``eval_()`` is kept as an alias for existing callers.
    """
    def __init__(self, dimensionality):
        """
        Args:
            dimensionality: size of the CLS representation of each span
        """
        super().__init__()

        self.linear1 = torch.nn.Linear(dimensionality, dimensionality)
        self.linear2 = torch.nn.Linear(dimensionality, 1)
        self.tanh = torch.tanh
        self.linear3 = torch.nn.Linear(1, 1)
        self.sig = torch.nn.Sigmoid()
        self.softmax = torch.nn.Softmax(dim=1)
        # Kept for callers that read it; forward() no longer relies on it
        # and follows the device of its inputs instead.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # This layer has always started in evaluation mode; keep that
        # default without shadowing nn.Module.eval with a boolean.
        super().train(False)

    def eval_(self):
        """Switch to evaluation mode (alias of ``eval()``)."""
        self.train(False)

    def forward(
              self,
              rep: torch.Tensor,
              logit: torch.Tensor) -> dict:
        """
        Forward pass for ent/contradiction detection attention layers

        Args:
            rep: CLS representations of all spans (span no x CLS dimensions)
            logit: logit for all spans (span no x 1)

        Returns:
            output_val: dictionary, containing:
                'sent_output': probability of ent/cont. for sentence
                'att_weights': normalized attention weights
                'att_unnorm': unnormalized attention weights
                'dropout': dropout mask applied (currently all ones)
        """

        # span no x CLS dimensions:
        val = self.linear1(rep)
        val = self.tanh(val)

        # span no x 1, each score in (0, 1):
        val = self.linear2(val)
        val = self.sig(val)

        sum_val = torch.sum(val)
        att_unnorm = val
        inv_sum_val = 1/sum_val

        # span no x 1, sums to one over all spans:
        att_weights = val*inv_sum_val

        # Create the mask on the same device/dtype as the weights rather
        # than on a globally guessed device, so a CPU model also works on
        # a machine where CUDA happens to be available.
        dropout_mask = torch.ones_like(att_weights)

        if self.training:
            # The mask is all ones, so this only renormalises the weights.
            att_weights = att_weights * dropout_mask
            att_unnorm = att_unnorm * dropout_mask

            new_sum = torch.sum(att_weights)
            att_weights = att_weights / new_sum

        # Attention-weighted sum of the span logits, shape (1,):
        updated_rep = torch.einsum('jk, jm -> k', [logit, att_weights])
        output_val = self.linear3(updated_rep)
        output_val = self.sig(output_val)

        # Preparing dictionary of outputs
        output_dict = {}
        output_dict['sent_output'] = output_val
        output_dict['att_weights'] = att_weights
        output_dict['att_unnorm'] = att_unnorm
        output_dict['dropout'] = dropout_mask

        return output_dict

class LogicModel(nn.Module):
    """
    Our Logic NLI model

    A pretrained sequence-classification encoder produces a CLS
    representation and two logits for every row of the batch; two
    attention layers (entailment detection and contradiction detection)
    then combine the rows into sentence-level outputs. Only
    ``batch['label'][0]`` is read, i.e. the whole batch is treated as
    belonging to one labelled example.
    """
    def __init__(self, dimensionality, model_type):
        """
        Args:
            dimensionality: size of the encoder's CLS representation
            model_type: name or path passed to ``from_pretrained``
        """
        super().__init__()

        self.attention_cont = AttentionClassifier(
                dimensionality)
        self.attention_ent = AttentionClassifier(
                dimensionality)

        self.encoder = AutoModelForSequenceClassification.from_pretrained(
            model_type,
            output_attentions=True,
            output_hidden_states=True,
            num_labels=2)

        # Kept for callers that read it; label tensors are created on the
        # device of the encoder outputs instead.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.model_type = model_type

    def forward(
            self,
            batch: dict) -> dict:
        """
        We create batch input from different span masks, and pass through model

        Args:
            batch: batch input into model ('input_ids', 'attention_mask',
                'label' and, when the tokenizer provides it,
                'token_type_ids')

        Returns:
            outputs: sentence outputs for each class, with unnormalized
                attention, which spans to supervise, and the labels
        """

        # Pass inputs by keyword: the positional order of forward() differs
        # between encoder architectures, so positional arguments could bind
        # token_type_ids to an unrelated parameter without any error.
        encoder_inputs = {
            'input_ids': batch['input_ids'],
            'attention_mask': batch['attention_mask'],
        }
        # Not every tokenizer produces segment ids; only forward them when
        # the batch actually contains them.
        if 'token_type_ids' in batch:
            encoder_inputs['token_type_ids'] = batch['token_type_ids']

        # We pass each span through the encoder (spans created from masking)
        span_outputs_dict = self.encoder(**encoder_inputs, return_dict=True)

        # CLS (first token) vector of the last hidden layer for every span
        all_spans_cls = span_outputs_dict['hidden_states'][-1][:, 0, :]
        all_spans_logits = span_outputs_dict['logits']

        # We set the labels for our entailment and contradiction detection
        # layers: 0 -> entailment, 1 -> neither, anything else -> contradiction
        true_label = batch['label'][0]
        if true_label == 0:
            ent_label, cont_label = 1, 0
        elif true_label == 1:
            ent_label, cont_label = 0, 0
        else:
            ent_label, cont_label = 0, 1

        # We find the outputs from the entailment/contradiction detection layers
        outputs = self.ent_and_cont_detection(
                    true_label,
                    all_spans_logits,
                    all_spans_cls,
                    ent_label=ent_label,
                    cont_label=cont_label)

        return outputs


    def ent_and_cont_detection(
        self,
        true_label,
        all_spans_logits: torch.Tensor,
        all_spans_cls: torch.Tensor,
        ent_label: int,
        cont_label: int,
        ) -> dict:
        """
        Apply the entailment and contradiction detection attention layers

        Args:
            true_label: original label of the example, returned unchanged
            all_spans_logits: logits for each span (for both classes)
            all_spans_cls: cls representation for each span
            ent_label: entailment detection label
            cont_label: cont detection label

        Returns:
            output_dict: dict of outputs for NLI sentence pairs, with keys
                'true_label', 'ent' and 'cont'
        """

        # Put the label tensors next to the model outputs instead of on a
        # globally guessed device, so they can be compared without a
        # device mismatch wherever the model actually lives.
        label_device = all_spans_logits.device

        # Column 0 of the logits feeds the entailment detection layer
        ent_output = self.attention_ent(
                all_spans_cls,
                all_spans_logits[:, 0].unsqueeze(1))

        ent_output['label'] = torch.tensor(
                [ent_label], device=label_device)

        # Column 1 of the logits feeds the contradiction detection layer
        cont_output = self.attention_cont(
                all_spans_cls,
                all_spans_logits[:, 1].unsqueeze(1))

        cont_output['label'] = torch.tensor(
                [cont_label], device=label_device)

        output_dict = {
                'true_label': true_label,
                'ent': ent_output,
                'cont': cont_output}

        return output_dict
